from torch import nn

from models.resnet import ConvBlock, IdentityBlock
from utils.transporter_utils import preprocess


class DownSample(nn.Module):
    """Strided residual feature extractor.

    The input is first passed through ``preprocess(x, dist='transporter')``.
    It then goes through a stem ``Conv2d(in_dim, 64, kernel_size=3, stride=1,
    padding=1)`` plus in-place ReLU, followed by four (ConvBlock,
    IdentityBlock) stages of widths 64, 128, 256 and ``out_dim``. Every
    ConvBlock uses ``stride=2``, so the spatial resolution is presumably
    reduced 16x overall (depends on ``ConvBlock`` internals — TODO confirm).

    Args:
        in_dim: Number of channels of the (preprocessed) input.
        out_dim: Channel width of the last stage.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        self.layers = self._make_layers()

    def _make_layers(self):
        """Assemble all layers into a single ``nn.Sequential``.

        Modules are created in the same order they appear in the sequence.
        Keep that order stable: ``state_dict`` keys are positional
        (``layers.0``, ``layers.1``, ...), so reordering breaks checkpoints.
        """
        stem = [
            nn.Conv2d(self.in_dim, 64, stride=1, kernel_size=3, padding=1),
            nn.ReLU(True),
        ]

        stages = []
        channels_in = 64
        for width in (64, 128, 256):
            stages.append(ConvBlock(channels_in, [width] * 3,
                                    kernel_size=3, stride=2))
            stages.append(IdentityBlock(width, [width] * 3,
                                        kernel_size=3, stride=1))
            channels_in = width

        final_width = self.out_dim
        stages.append(ConvBlock(channels_in, [final_width] * 3,
                                kernel_size=3, stride=2))
        # The last IdentityBlock intentionally relies on its default stride.
        stages.append(IdentityBlock(final_width, [final_width] * 3,
                                    kernel_size=3))

        return nn.Sequential(*stem, *stages)

    def forward(self, x):
        """Apply the transporter preprocessing to ``x``, then run the layers."""
        return self.layers(preprocess(x, dist='transporter'))


class UpSample(nn.Module):
    """Residual decoder that upsamples by 2x four times.

    Each of the four stages is a 2x bilinear upsample (``align_corners=True``)
    followed by a ``ConvBlock``/``IdentityBlock`` pair. The blocks use
    ``stride=1``, so presumably only the upsamples change the spatial size
    (16x overall) — TODO confirm against ``models.resnet``. Channel widths go
    ``in_dim -> 256 -> 128 -> 64 -> out_dim``; the last two blocks use
    16-channel intermediate filters. They are also built with
    ``final_relu=False``, presumably so the output is not clamped to be
    non-negative (TODO confirm).

    Args:
        in_dim: Number of channels of the input feature map.
        out_dim: Number of channels produced by the final block.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        self.layers = self._make_layers()

    @staticmethod
    def _upsample2x():
        """Return a 2x bilinear upsampler with ``align_corners=True``.

        This is the non-deprecated equivalent of
        ``nn.UpsamplingBilinear2d(scale_factor=2)``, which PyTorch documents
        as ``interpolate(..., mode='bilinear', align_corners=True)``.
        """
        return nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def _make_layers(self):
        """Assemble all layers into a single ``nn.Sequential``.

        Keep the module order stable: ``state_dict`` keys are positional
        (``layers.1``, ``layers.2``, ...), so reordering breaks checkpoints.
        """
        layers = nn.Sequential(
            self._upsample2x(),
            # head
            ConvBlock(self.in_dim, [256, 256, 256], kernel_size=3, stride=1),
            IdentityBlock(256, [256, 256, 256], kernel_size=3, stride=1),
            self._upsample2x(),

            ConvBlock(256, [128, 128, 128], kernel_size=3, stride=1),
            IdentityBlock(128, [128, 128, 128], kernel_size=3, stride=1),
            self._upsample2x(),

            ConvBlock(128, [64, 64, 64], kernel_size=3, stride=1),
            IdentityBlock(64, [64, 64, 64], kernel_size=3, stride=1),
            self._upsample2x(),

            # conv2: output head, no trailing ReLU on either block
            ConvBlock(64, [16, 16, self.out_dim], kernel_size=3, stride=1,
                      final_relu=False),
            IdentityBlock(self.out_dim, [16, 16, self.out_dim],
                          kernel_size=3, stride=1, final_relu=False)
        )
        return layers

    def forward(self, x):
        """Run ``x`` through the upsampling layers and return the result."""
        return self.layers(x)
